/*
 * Copyright (C) Anton Naruta && Daniel Hoppener
 * MAVLab Delft University of Technology
 *
 * This file is part of paparazzi.
 *
 * paparazzi is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2, or (at your option)
 * any later version.
 *
 * paparazzi is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with paparazzi; see the file COPYING.  If not, write to
 * the Free Software Foundation, 59 Temple Place - Suite 330,
 * Boston, MA 02111-1307, USA.
 */

/** @file wls_alloc.c
 * @brief This is an active set algorithm for WLS control allocation
 *
 * This algorithm will find the optimal inputs to produce the least error wrt
 * the control objective, taking into account the weighting matrices on the
 * control objective and the control effort.
 *
 * The algorithm is described in:
 * Prioritized Control Allocation for Quadrotors Subject to Saturation -
 * E.J.J. Smeur, D.C. Höppener, C. de Wagter. In IMAV 2017
 *
 * written by Anton Naruta && Daniel Hoppener 2016
 * MAVLab Delft University of Technology
 */

#include "wls_alloc.h"
#include "std.h"

#include <string.h>
#include <math.h>
#include <float.h>
#include "math/qr_solve/qr_solve.h"
#include "math/qr_solve/r8lib_min.h"

// provide loop feedback
#ifndef WLS_VERBOSE
#define WLS_VERBOSE FALSE
#endif

#if WLS_VERBOSE
#include <stdio.h>
static void print_final_values(struct WLS_t* WLS_p, float **B);
static void print_in_and_outputs(int n_c, int n_free, float **A_free_ptr, float *d, float *p_free);
#endif

/* Define messages of the module*/
#if PERIODIC_TELEMETRY
#include "modules/datalink/telemetry.h"
/**
 * @brief Send the WLS_V telemetry message for one allocator instance
 *
 * Transmits the name, gamma_sq, the iteration count (narrowed to uint8_t)
 * and the nv-element arrays v and Wv of the given WLS struct.
 *
 * @param name  NUL-terminated label sent along with the message
 * @param WLS_p WLS struct whose values are reported
 * @param trans Transport used for sending
 * @param dev   Link device used for sending
 */
void send_wls_v(char *name, struct WLS_t *WLS_p, struct transport_tx *trans, struct link_device *dev)
{
  // the message takes the iteration count by pointer to a uint8_t
  uint8_t iter_u8 = (uint8_t)WLS_p->iter;
  const size_t name_len = strlen(name);

  pprz_msg_send_WLS_V(trans, dev, AC_ID,
                      name_len, name,
                      &WLS_p->gamma_sq,
                      &iter_u8,
                      WLS_p->nv, WLS_p->v,
                      WLS_p->nv, WLS_p->Wv);
}
/**
 * @brief Send the WLS_U telemetry message for one allocator instance
 *
 * Transmits the name and the nu-element arrays Wu, u_pref, u_min, u_max
 * and u of the given WLS struct.
 *
 * @param name  NUL-terminated label sent along with the message
 * @param WLS_p WLS struct whose values are reported
 * @param trans Transport used for sending
 * @param dev   Link device used for sending
 */
void send_wls_u(char *name, struct WLS_t *WLS_p, struct transport_tx *trans, struct link_device *dev)
{
  const size_t name_len = strlen(name);

  pprz_msg_send_WLS_U(trans, dev, AC_ID,
                      name_len, name,
                      WLS_p->nu, WLS_p->Wu,
                      WLS_p->nu, WLS_p->u_pref,
                      WLS_p->nu, WLS_p->u_min,
                      WLS_p->nu, WLS_p->u_max,
                      WLS_p->nu, WLS_p->u);
}
#endif

/**
 * @brief Adapter between the row-pointer matrix used here and qr_solve
 *
 * Copies A into a contiguous column-major buffer (element (i, j) ends up at
 * offset j * m + i) and hands that buffer to qr_solve together with b and x.
 * Swap the final call to use a different solver for Ax = b.
 *
 * @param m number of rows of A
 * @param n number of columns of A
 * @param A matrix given as m row pointers, each with at least n entries
 * @param b right-hand side, passed to the solver unchanged
 * @param x output vector, filled in by the solver
 */
static void qr_solve_wrapper(int m, int n, float **A, float *b, float *x) {
  float packed[m * n];

  // flatten A one column after the other
  for (int col = 0; col < n; col++) {
    const int base = col * m;
    for (int row = 0; row < m; row++) {
      packed[base + row] = A[row][col];
    }
  }

  qr_solve(m, n, packed, b, x);
}

/**
 * @brief active set algorithm for control allocation
 *
 * Takes the control objective and max and min inputs from pprz and calculates
 * the inputs that will satisfy most of the control objective, subject to the
 * weighting matrices Wv and Wu
 *  
 * @param WLS_p Struct that contains most of the WLS parameters
 * @param B The control effectiveness matrix
 * @param u_guess Initial value for u
 * @param W_init Initial working set, if known
 * @param imax Max number of iterations
 */

void wls_alloc(struct WLS_t* WLS_p, float **B, float *u_guess, float *W_init, int imax) {
  // allocate variables, use defaults where parameters are set to 0
  if (!WLS_p->gamma_sq) WLS_p->gamma_sq = 100000;
  if (!imax) imax = 100;

  int n_c = WLS_p->nu + WLS_p->nv;

  float A[n_c][WLS_p->nu];
  float A_free[n_c][WLS_p->nu];

  // Create a pointer array to the rows of A_free
  // such that we can pass it to a function
  float *A_free_ptr[n_c];
  for(int i = 0; i < n_c; i++)
    A_free_ptr[i] = A_free[i];

  float b[n_c];
  float d[n_c];

  int free_index[WLS_p->nu];
  int free_index_lookup[WLS_p->nu];
  int n_free = 0;
  int free_chk = -1;

  int iter = 0;
  float p_free[WLS_p->nu];
  float p[WLS_p->nu];
  float u_opt[WLS_p->nu];
  int infeasible_index[WLS_p->nu] UNUSED;
  int n_infeasible = 0;
  float lambda[WLS_p->nu];
  float W[WLS_p->nu];

  // Initialize u and the working set, if provided from input
  if (!u_guess) {
    for (int i = 0; i < WLS_p->nu; i++) {
      WLS_p->u[i] = (WLS_p->u_max[i] + WLS_p->u_min[i]) * 0.5;
    }
  } else {
    for (int i = 0; i < WLS_p->nu; i++) {
      WLS_p->u[i] = u_guess[i];
    }
  }
  W_init ? memcpy(W, W_init, WLS_p->nu * sizeof(float))
         : memset(W, 0, WLS_p->nu * sizeof(float));

  memset(free_index_lookup, -1, WLS_p->nu * sizeof(float));

  // find free indices
  for (int i = 0; i < WLS_p->nu; i++) {
    if (W[i] == 0) {
      free_index_lookup[i] = n_free;
      free_index[n_free++] = i;
    }
  }

  // fill up A, A_free, b and d
  for (int i = 0; i < WLS_p->nv; i++) {
    b[i] = WLS_p->gamma_sq * WLS_p->Wv[i] * WLS_p->v[i];
    d[i] = b[i];
    for (int j = 0; j < WLS_p->nu; j++) {
      // If Wv is a NULL pointer, use Wv = identity
      A[i][j] = WLS_p->gamma_sq * WLS_p->Wv[i] * B[i][j];
      d[i] -= A[i][j] * WLS_p->u[j];
    }
  }
  for (int i = WLS_p->nv; i < n_c; i++) {
    memset(A[i], 0, WLS_p->nu * sizeof(float));
    A[i][i - WLS_p->nv] = WLS_p->Wu[i - WLS_p->nv];
    b[i] = WLS_p->Wu[i - WLS_p->nv] * WLS_p->u_pref[i - WLS_p->nv];
    d[i] = b[i] - A[i][i - WLS_p->nv] * WLS_p->u[i - WLS_p->nv];
  }

  // -------------- Start loop ------------
  while (iter++ < imax) {
    // clear p, copy u to u_opt
    memset(p, 0, WLS_p->nu * sizeof(float));
    memcpy(u_opt, WLS_p->u, WLS_p->nu * sizeof(float));

    // Construct a matrix with the free columns of A
    if (free_chk != n_free) {
      for (int i = 0; i < n_c; i++) {
        for (int j = 0; j < n_free; j++) {
          A_free[i][j] = A[i][free_index[j]];
        }
      }
      free_chk = n_free;
    }


    // Count the infeasible free actuators
    n_infeasible = 0;

    if (n_free > 0) {
      // Still free variables left, calculate corresponding solution

      // use a solver to find the solution to A_free*p_free = d
      qr_solve_wrapper(n_c, n_free, A_free_ptr, d, p_free);

      //print results current step
#if WLS_VERBOSE
      print_in_and_outputs(n_c, n_free, A_free_ptr, d, p_free);
#endif

      // Set the nonzero values of p and add to u_opt
      for (int i = 0; i < n_free; i++) {
        p[free_index[i]] = p_free[i];
        u_opt[free_index[i]] += p_free[i];

        // check limits
        if ((u_opt[free_index[i]] > WLS_p->u_max[free_index[i]] || u_opt[free_index[i]] < WLS_p->u_min[free_index[i]])) {
          infeasible_index[n_infeasible++] = free_index[i];
        }
      }
    }

    // Check feasibility of the solution
    if (n_infeasible == 0) {
      // all variables are within limits
      memcpy(WLS_p->u, u_opt, WLS_p->nu * sizeof(float));
      memset(lambda, 0, WLS_p->nu * sizeof(float));

      // d = d + A_free*p_free; lambda = A*d;
      for (int i = 0; i < n_c; i++) {
        for (int k = 0; k < n_free; k++) {
          d[i] -= A_free[i][k] * p_free[k];
        }
        for (int k = 0; k < WLS_p->nu; k++) {
          lambda[k] += A[i][k] * d[i];
        }
      }
      bool break_flag = true;

      // lambda = lambda x W;
      for (int i = 0; i < WLS_p->nu; i++) {
        lambda[i] *= W[i];
        // if any lambdas are negative, keep looking for solution
        if (lambda[i] < -FLT_EPSILON) {
          break_flag = false;
          W[i] = 0;
          // add a free index
          if (free_index_lookup[i] < 0) {
            free_index_lookup[i] = n_free;
            free_index[n_free++] = i;
          }
        }
      }
      if (break_flag) {

#if WLS_VERBOSE
        print_final_values(WLS_p, B);
#endif
        
        // if solution is found, return number of iterations
        WLS_p->iter = iter;
      }
    } else {
      // scaling back actuator command (0-1)
      float alpha = 1.0;
      float alpha_tmp;
      int id_alpha = free_index[0];

      // find the lowest distance from the limit among the free variables
      for (int i = 0; i < n_free; i++) {
        int id = free_index[i];

        alpha_tmp = (p[id] < 0) ? (WLS_p->u_min[id] - WLS_p->u[id]) / p[id]
                                : (WLS_p->u_max[id] - WLS_p->u[id]) / p[id];

        if (isnan(alpha_tmp) || alpha_tmp < 0.f) {
          alpha_tmp = 1.0f;
        }
        if (alpha_tmp < alpha) {
          alpha = alpha_tmp;
          id_alpha = id;
        }
      }

      // update input u = u + alpha*p
      for (int i = 0; i < WLS_p->nu; i++) {
        WLS_p->u[i] += alpha * p[i];
        Bound(WLS_p->u[i], WLS_p->u_min[i], WLS_p->u_max[i]);
      }
      // update d = d-alpha*A*p_free
      for (int i = 0; i < n_c; i++) {
        for (int k = 0; k < n_free; k++) {
          d[i] -= A_free[i][k] * alpha * p_free[k];
        }
      }
      // get rid of a free index
      W[id_alpha] = (p[id_alpha] > 0) ? 1.0 : -1.0;

      free_index[free_index_lookup[id_alpha]] = free_index[--n_free];
      free_index_lookup[free_index[free_index_lookup[id_alpha]]] =
          free_index_lookup[id_alpha];
      free_index_lookup[id_alpha] = -1;
    }
  }
  WLS_p->iter = iter;
}

#if WLS_VERBOSE
/* Print `count` floats from `values`, each followed by a single space. */
static void wls_print_floats(const float *values, int count) {
  for (int idx = 0; idx < count; idx++) {
    printf("%f ", values[idx]);
  }
}

/**
 * @brief Print the input and output of one solver step to stdout
 *
 * @param n_c        number of rows of A_free and entries of d
 * @param n_free     number of columns of A_free and entries of p_free
 * @param A_free_ptr row pointers of the matrix passed to the solver
 * @param d          right-hand side passed to the solver
 * @param p_free     solution returned by the solver
 */
static void print_in_and_outputs(int n_c, int n_free, float **A_free_ptr, float *d, float *p_free) {
  printf("n_c = %d n_free = %d\n", n_c, n_free);

  // matrix, one row per line
  printf("A_free =\n");
  for (int row = 0; row < n_c; row++) {
    for (int col = 0; col < n_free; col++) {
      printf("%f ", A_free_ptr[row][col]);
    }
    printf("\n");
  }

  printf("d = ");
  wls_print_floats(d, n_c);

  printf("\noutput = ");
  wls_print_floats(p_free, n_free);
  printf("\n\n");
}

/**
 * @brief Print the allocation problem and its result to stdout
 *
 * Prints nu and nv, the nv x nu matrix B, and from the struct the vectors
 * v (nv entries), u, u_min and u_max (nu entries each).
 *
 * @param WLS_p WLS struct holding the values to print
 * @param B     The control effectiveness matrix, indexed as B[0..nv-1][0..nu-1]
 */
static void print_final_values(struct WLS_t* WLS_p, float **B) {
  printf("n_u = %d n_v = %d\n", WLS_p->nu, WLS_p->nv);

  printf("B =\n");
  for (int i = 0; i < WLS_p->nv; i++) {
    for (int j = 0; j < WLS_p->nu; j++) {
      printf("%f ", B[i][j]);
    }
    printf("\n");
  }

  printf("v = ");
  for (int j = 0; j < WLS_p->nv; j++) {
    printf("%f ", WLS_p->v[j]);
  }

  printf("\nu = ");
  for (int j = 0; j < WLS_p->nu; j++) {
    // u lives in the struct; there is no local or global named u here
    printf("%f ", WLS_p->u[j]);
  }
  printf("\n");

  printf("\numin = ");
  for (int j = 0; j < WLS_p->nu; j++) {
    printf("%f ", WLS_p->u_min[j]);
  }
  printf("\n");

  printf("\numax = ");
  for (int j = 0; j < WLS_p->nu; j++) {
    printf("%f ", WLS_p->u_max[j]);
  }
  printf("\n\n");
}
#endif
